package com.pai4j.search.service.voctor;

import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.pai4j.common.constants.MilvusConstants;
import com.pai4j.common.enums.EmbeddingModelEnum;
import com.pai4j.common.util.HtmlSplitter;
import com.pai4j.common.util.JsonUtil;
import com.pai4j.domain.dto.search.SyncSearchEngineDTO;
import com.pai4j.domain.dto.search.SyncSearchEngineListDTO;
import com.pai4j.domain.vo.response.search.ArticleMilvusSearchResponse;
import io.milvus.v2.common.DataType;
import io.milvus.v2.common.IndexParam;
import io.milvus.v2.service.collection.request.AddFieldReq;
import io.milvus.v2.service.collection.request.CreateCollectionReq;
import io.milvus.v2.service.vector.request.AnnSearchReq;
import io.milvus.v2.service.vector.request.HybridSearchReq;
import io.milvus.v2.service.vector.request.InsertReq;
import io.milvus.v2.service.vector.request.data.FloatVec;
import io.milvus.v2.service.vector.request.ranker.WeightedRanker;
import io.milvus.v2.service.vector.response.InsertResp;
import io.milvus.v2.service.vector.response.SearchResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.concurrent.ArrayBlockingQueue;


@Slf4j
@Service
public class ArticleMilvusService extends AbstractMilvusQueueService<SyncSearchEngineListDTO, ArticleMilvusSearchResponse, SyncSearchEngineDTO> {


    /**
     * 本地队列初始长度
     */
    private static final int CAPACITY = 2000;

    /**
     * 本地队列，用于同步数据入向量库
     */
    public ArrayBlockingQueue<SyncSearchEngineDTO> queue = new ArrayBlockingQueue(CAPACITY);

    /**
     * 纬度，跟embedding模型纬度保持一致
     * @return
     */
    @Override
    protected int dimension() {
        return EmbeddingModelEnum.ALIBABA.getDimension();
    }

    /**
     * 集合定义
     * @return
     */
    @Override
    protected String collectionName() {

        return MilvusConstants.CollectionName.COLLECTION_ARTICLE.name();
    }

    /**
     * schema 定义
     *
     * @return
     */
    @Override
    protected CreateCollectionReq.CollectionSchema schema() {

        CreateCollectionReq.CollectionSchema schema = client.createSchema();
        schema.addField(AddFieldReq.builder()
                .fieldName("id")
                .dataType(DataType.VarChar)
                .isPrimaryKey(true)
                .autoID(true)
                .build());

        /**
         * 业务字段
         */
        schema.addField(AddFieldReq.builder()
                .fieldName("pk")
                .dataType(DataType.VarChar)
                .maxLength(64)
                .build());
        schema.addField(AddFieldReq.builder()
                .fieldName("source_type")
                .dataType(DataType.VarChar)
                .maxLength(64)
                .build());
        schema.addField(AddFieldReq.builder()
                .fieldName("summary")
                .dataType(DataType.VarChar)
                .maxLength(500)
                .build());
        schema.addField(AddFieldReq.builder()
                .fieldName("content")
                .dataType(DataType.VarChar)
                .maxLength(65535)
                .build());

        /**
         * 向量字段
         */
        schema.addField(AddFieldReq.builder()
                .fieldName("summary_vector")
                .dataType(DataType.FloatVector)
                .dimension(this.dimension())
                .build());
        schema.addField(AddFieldReq.builder()
                .fieldName("content_vector")
                .dataType(DataType.FloatVector)
                .dimension(this.dimension())
                .build());
        return schema;
    }

    /**
     * 索引定义
     * @return
     */
    @Override
    protected List<IndexParam> indexParams() {

        IndexParam summaryVectorIndex = IndexParam.builder()
                .fieldName("summary_vector")
                // 余弦相似度
                .metricType(IndexParam.MetricType.COSINE)
                .build();
        IndexParam contentVectorIndex = IndexParam.builder()
                .fieldName("content_vector")
                // 余弦相似度
                .metricType(IndexParam.MetricType.COSINE)
                .build();

        return List.of(summaryVectorIndex, contentVectorIndex);
    }

    /**
     * 入队业务逻辑实现，实际入队由父级putTaskQueue完成。
     * @param syncData
     * @return
     */
    @Override
    protected boolean addTaskQueue(SyncSearchEngineListDTO syncData) {
        if (CollectionUtils.isEmpty(syncData.getDataList())) {
            return true;
        }
        // 文章同步任务加入队列
        syncData.getDataList().forEach(super::putTaskQueue);
        return true;
    }

    /**
     * 基于相似度的topK条数据搜索
     *
     * 混合多向量搜索
     *
     * @param query
     * @param topK
     * @return
     */
    @Override
    public List<ArticleMilvusSearchResponse> search(String query, int topK, float score) {

        // 对用户query做embedding：这里embedding模型需要同同原数据embed采用同模型
        float[] queryVector = super.embed(query);
        AnnSearchReq summaryReq = AnnSearchReq.builder()
                // 采用相似度检索算法：余弦相似度
                .metricType(IndexParam.MetricType.COSINE)
                .vectorFieldName("summary_vector")
                .topK(topK)
                .vectors(Collections.singletonList(new FloatVec(queryVector)))
                .build();
        AnnSearchReq contentReq = AnnSearchReq.builder()
                // 采用相似度检索算法：余弦相似度
                .metricType(IndexParam.MetricType.COSINE)
                .vectorFieldName("content_vector")
                .topK(topK)
                .vectors(Collections.singletonList(new FloatVec(queryVector)))
                .build();
        HybridSearchReq request = HybridSearchReq.builder()
                .collectionName(this.collectionName())
                .searchRequests(Lists.newArrayList(summaryReq, contentReq))
                .ranker(new WeightedRanker(Arrays.asList(0.5f, 0.5f)))
                .topK(topK)
                .outFields(Lists.newArrayList("pk", "summary", "content"))
                .build();
        SearchResp searchResp = client.hybridSearch(request);
        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
        List<SearchResp.SearchResult> queryResults;
        if (CollectionUtils.isEmpty(searchResults) || CollectionUtils.isEmpty(queryResults = searchResults.get(0))) {
            return Collections.emptyList();
        }
        return queryResults.stream()
                .filter(r -> r.getScore() > score)
                .map(r -> {
                    Map<String, Object> resultMap = r.getEntity();
                    return new ArticleMilvusSearchResponse(Long.valueOf(String.valueOf(resultMap.get("pk"))),
                                                           String.valueOf(resultMap.getOrDefault("summary", "")),
                                                           String.valueOf(resultMap.getOrDefault("content", "")));
                }).toList();
    }

    /**
     * 队列
     * @return
     */
    @Override
    protected ArrayBlockingQueue<SyncSearchEngineDTO> queue() {
        return queue;
    }

    /**
     * 真正的数据写入实现
     * @param data
     * @return
     */
    @Override
    public boolean doWriteCollection(SyncSearchEngineDTO data) {

        String content = data.getContent();
        // 正文html长富文本切片
        List<HtmlSplitter.SplitResult> contentObjs = HtmlSplitter.split(content, true);
        if (CollectionUtils.isEmpty(contentObjs)) {
            return false;
        }
        String pk = data.getPk();
        String sourceType = data.getSourceType();
        String summary = data.getSummary();

        List<String> contents = contentObjs.stream().map(HtmlSplitter.SplitResult::getContent).toList();
        List<String> summaries = contentObjs.stream().map(c ->
                                c.getTitle().concat("|").concat(data.getSummary())).toList();

        // 批量embedding
        List<float[]> contentVectors = super.embed(contents);
        List<float[]> summaryVectors = super.embed(summaries);
        List<JsonObject> vectors = new ArrayList<>();
        for (int i = 0; i < contentVectors.size(); i++) {
            vectors.add(
                    buildMilvusData(
                            pk,
                            sourceType,
                            summaries.get(i),
                            contents.get(i),
                            summaryVectors.get(i),
                            contentVectors.get(i)));
        }

        /**
         * content原文章内容切片后批量写
         */
        InsertReq request = InsertReq.builder()
                .collectionName(collectionName())
                .data(vectors)
                .build();
        log.info("文章向量数据入库>>>> pk:{}", data.getPk());
        InsertResp resp = client.insert(request);
        log.info("文章向量数据入库完成>>>> pk:{}， resp:{}", data.getPk(), JsonUtil.toJsonString(resp));
        return true;
    }

    private JsonObject buildMilvusData(String pk, String sourceType,
                                       String summary, String content,
                                       float[] summaryVector, float[] contentVector) {
        JsonObject vector = new JsonObject();
        /**
         * 原数据写入
         */
        vector.addProperty("pk", pk);
        vector.addProperty("source_type", sourceType);
        vector.addProperty("summary", summary);
        vector.addProperty("content", content);
        /**
         * 向量字段处理
         */
        Gson gson = new Gson();
        vector.add("summary_vector", gson.toJsonTree(summaryVector));
        vector.add("content_vector", gson.toJsonTree(contentVector));
        return vector;
    }
}
